
Sampling from Distributions.product_distribution allocates (a lot)

Open · JADekker opened this issue 10 months ago • 8 comments

Hi, I noticed that when sampling a matrix of values from a Distributions.product_distribution, a lot of allocations are made (in my example, more allocations than sampled values!) if I call rand(..., N) directly on a product_distribution built from different distributions. However, my naive implementations (rand_quick, rand_quick2) beat the default implementation by roughly a factor of 10 on my device and cut allocations to a constant amount in the matrix case. rand_quick also provides a modest speed-up on my device in the vector case. Am I misusing product_distribution, or is there room for improvement? See this discourse thread for my initial question that led to this issue: https://discourse.julialang.org/t/product-distribution-allocates-a-lot/126771/2.

using Distributions, BenchmarkTools, Random
Random.seed!(42)

function rand_quick(d::Product) 
    N_out = Vector{Float64}(undef, length(d.v))
    for (i, dist) in enumerate(d.v)
        N_out[i] = rand(dist)
    end
    return N_out
end
rand_quick2(d::Product) = [rand(dist) for dist in d.v]
function rand_quick(d::Product, N::Int64)
    N_out = Matrix{Float64}(undef, N, length(d.v))
    for (i, dist) in enumerate(d.v)
        N_out[:, i] .= rand(dist, N)
    end
    return permutedims(N_out)
end
rand_quick2(d::Product, N::Int64) = vcat([rand(dist, N)' for dist in d.v]...)

function run_tests(N::Int64)
    v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)]
    d = product_distribution(v_dists)
    display("Benchmarking single draws")
    display(@benchmark rand($d))
    display(@benchmark rand_quick($d))
    display(@benchmark rand_quick2($d))
    display("Benchmarking samples of size $N")
    display(@benchmark rand($d, $N))
    display(@benchmark rand_quick($d, $N))
    display(@benchmark rand_quick2($d, $N))
    return nothing
end

run_tests(1_000_000)

gives

"Benchmarking single draws"
BenchmarkTools.Trial: 10000 samples with 836 evaluations per sample.
 Range (min … max):  143.840 ns …  67.004 μs  ┊ GC (min … max): 0.00% … 99.73%
 Time  (median):     158.044 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   168.521 ns ± 669.044 ns  ┊ GC (mean ± σ):  4.36% ±  2.43%

       ▄▇▇█▇▃▂▁ ▁▁▃▄▄                                            
  ▁▂▃▆▇███████████████▆▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁ ▃
  144 ns           Histogram: frequency by time          210 ns <

 Memory estimate: 128 bytes, allocs estimate: 4.
BenchmarkTools.Trial: 10000 samples with 858 evaluations per sample.
 Range (min … max):  136.995 ns …  69.691 μs  ┊ GC (min … max): 0.00% … 99.72%
 Time  (median):     151.759 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   170.850 ns ± 724.923 ns  ┊ GC (mean ± σ):  4.52% ±  2.61%

     █▆▄▁ ▂▆▁                                                    
  ▁▂█████▇███▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  137 ns           Histogram: frequency by time          253 ns <

 Memory estimate: 128 bytes, allocs estimate: 4.
BenchmarkTools.Trial: 10000 samples with 545 evaluations per sample.
 Range (min … max):  209.939 ns …  96.199 μs  ┊ GC (min … max): 0.00% … 99.72%
 Time  (median):     224.007 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   238.015 ns ± 960.152 ns  ┊ GC (mean ± σ):  4.21% ±  1.78%

   ▁▄▃▅▇▇▇▇███▇▆▆▆▆▄▃▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁      ▁                     ▃
  ▆████████████████████████████████████████████▇▇█▇▇▇▇▇█▇▆▆▇▆▇▆ █
  210 ns        Histogram: log(frequency) by time        290 ns <

 Memory estimate: 160 bytes, allocs estimate: 6.
"Benchmarking samples of size 1000000"
BenchmarkTools.Trial: 31 samples with 1 evaluation per sample.
 Range (min … max):  149.399 ms … 194.403 ms  ┊ GC (min … max): 1.35% … 25.34%
 Time  (median):     156.147 ms               ┊ GC (median):    4.97%
 Time  (mean ± σ):   161.630 ms ±   9.801 ms  ┊ GC (mean ± σ):  9.54% ±  5.40%

    ▁▄  ▄▄▁              ▄ ▁▁ █                                  
  ▆▁██▆▁███▆▁▁▁▁▁▁▁▁▁▁▁▆▁█▆██▆█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
  149 ms           Histogram: frequency by time          194 ms <

 Memory estimate: 114.43 MiB, allocs estimate: 5999491.
BenchmarkTools.Trial: 323 samples with 1 evaluation per sample.
 Range (min … max):  14.128 ms … 63.315 ms  ┊ GC (min … max): 4.72% … 77.16%
 Time  (median):     15.284 ms              ┊ GC (median):    6.86%
 Time  (mean ± σ):   15.520 ms ±  2.754 ms  ┊ GC (mean ± σ):  7.90% ±  4.68%

        ▁▁ ▁▄▃▃▄▄▄▄▄▇▃█▁▅▄▁▄ ▃▆ ▁                              
  ▃▁▁▅▃▇██▆████████████████████▇█▅▆▅█▄▃▆▆▃▁▃▄▅▁▁▃▆▃▄▃▄▁▃▁▃▃▃▃ ▅
  14.1 ms         Histogram: frequency by time        17.4 ms <

 Memory estimate: 68.67 MiB, allocs estimate: 22.
BenchmarkTools.Trial: 345 samples with 1 evaluation per sample.
 Range (min … max):  13.138 ms … 40.207 ms  ┊ GC (min … max): 0.00% … 66.52%
 Time  (median):     14.345 ms              ┊ GC (median):    6.93%
 Time  (mean ± σ):   14.497 ms ±  1.854 ms  ┊ GC (mean ± σ):  7.89% ±  4.86%

                  ▂   ▇█▄▄ ▄▁▇▅▆▇▄▂▂▄ ▂ ▁                      
  ▃▃▁▁▁▁▁▁▁▃▁▃▅▃▄███▇▇███████████████▆█▄█▆▅▃▅▄▃▃▃▃▃▃▃▃▃▁▁▁▁▁▃ ▄
  13.1 ms         Histogram: frequency by time        15.7 ms <

 Memory estimate: 45.78 MiB, allocs estimate: 20.

JADekker · Mar 11 '25 13:03

A further point (that is also raised in the thread linked above) is that using product_distribution with different types of distributions introduces type instability. Is there a way in Distributions.jl to avoid this?
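
A minimal way to see where the instability comes from (illustrative snippet, not part of the original report): the element type of the distribution vector is abstract, so every rand(dist) call during sampling is dynamically dispatched.

julia> using Distributions

julia> v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)];

julia> isconcretetype(eltype(v_dists))  # abstract element type => each draw is dynamically dispatched
false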

JADekker · Mar 11 '25 13:03

However, my naive implementations (rand_quick, rand_quick2) beat the default implementation by roughly a factor 10 on my device and cut allocations to a constant amount in the matrix case.

In general, the output type may not consist of Float64s. Note also that Distributions.Product is deprecated, which probably explains why it hasn't been developed further lately.
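
To illustrate the first point with a hypothetical example (not from the thread): a product of discrete components does not produce Float64 draws, so the hardcoded Vector{Float64}/Matrix{Float64} buffers in rand_quick would silently convert them.

julia> using Distributions

julia> rand(Poisson(3.0)) isa Int  # discrete components sample integers, not Float64
true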

Is there a way in Distributions.jl to avoid this?

~~No, not right now. Distributions.ProductDistribution (and the deprecated Distributions.Product) operates on vectors of distributions; by design, this creates type instabilities for vectors with different types of distributions.~~

In principle, product_distribution supports tuples; you can, e.g., define

d = product_distribution(Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0))

This can also be inferred. However, probably not every method is optimized for tuples.
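
As a quick sketch of how to check the inference claim (using Test.@inferred; not part of the original comment, and behaviour may differ across Distributions versions):

julia> using Distributions, Test

julia> @inferred product_distribution(Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0));  # passes when the concrete product type is inferred from the argument types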

devmotion · Mar 11 '25 14:03

The tuple version seems to be faster (as expected):

julia> using Distributions, BenchmarkTools

julia> d = product_distribution(Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0));

julia> rand(d)
3-element Vector{Float64}:
  1.4515762558532885
 -0.8798585959543993
  0.4173652746533735

julia> @benchmark rand($d)
BenchmarkTools.Trial: 10000 samples with 986 evaluations per sample.
 Range (min … max):  53.331 ns …   2.745 μs  ┊ GC (min … max):  0.00% … 97.27%
 Time  (median):     56.922 ns               ┊ GC (median):     0.00%
 Time  (mean ± σ):   68.167 ns ± 112.689 ns  ┊ GC (mean ± σ):  15.87% ±  9.21%

   ▁█▂      ▇▁
  ▃███▄▄▇▇▅███▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂ ▃
  53.3 ns         Histogram: frequency by time         79.4 ns <

 Memory estimate: 240 bytes, allocs estimate: 6.

devmotion · Mar 11 '25 14:03

Note also that Distributions.Product is deprecated, which probably explains why it hasn't been developed further lately.

This is why I used d = product_distribution(v_dists) indeed

The tuple version seems to be faster (as expected):

This is for 1 draw. With the tuple version and N = 1,000,000 draws, I get, using

function rand_quick(d, N::Int64)
    N_out = Matrix{Float64}(undef, N, length(d.dists))
    for (i, dist) in enumerate(d.dists)
        N_out[:, i] .= rand(dist, N)
    end
    return permutedims(N_out)
end
rand_quick2(d, N::Int64) = vcat([rand(dist, N)' for dist in d.dists]...)

that

    v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)]
    d_alt = product_distribution(v_dists...)
    display(@benchmark rand($d_alt, $N))
    display(@benchmark rand_quick($d_alt, $N))
    display(@benchmark rand_quick2($d_alt, $N))

produces

BenchmarkTools.Trial: 69 samples with 1 evaluation per sample.
 Range (min … max):  53.807 ms … 129.197 ms  ┊ GC (min … max):  6.04% … 34.83%
 Time  (median):     65.456 ms               ┊ GC (median):     7.91%
 Time  (mean ± σ):   72.510 ms ±  17.873 ms  ┊ GC (mean ± σ):  26.13% ± 16.29%

  █                                                             
  █▄▃▅▄▄▃▃▄▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▇▄▄▅▄▄▁▁▁▄▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁
  53.8 ms         Histogram: frequency by time          116 ms <

 Memory estimate: 175.48 MiB, allocs estimate: 4000002.
BenchmarkTools.Trial: 318 samples with 1 evaluation per sample.
 Range (min … max):  14.147 ms … 65.672 ms  ┊ GC (min … max): 3.84% … 77.85%
 Time  (median):     15.344 ms              ┊ GC (median):    7.01%
 Time  (mean ± σ):   15.716 ms ±  3.042 ms  ┊ GC (mean ± σ):  8.29% ±  4.91%

     ▄▄▃▅▇▂▄▇▇▆█▁▃▁                                            
  ▃▄▆██████████████▇▇▆▄▇▄▃▃▃▄▁▄▃▁▁▁▃▁▃▃▃▃▁▁▁▁▃▃▁▁▃▁▁▁▁▁▁▁▁▃▁▄ ▄
  14.1 ms         Histogram: frequency by time        20.4 ms <

 Memory estimate: 68.66 MiB, allocs estimate: 14.
BenchmarkTools.Trial: 343 samples with 1 evaluation per sample.
 Range (min … max):  13.553 ms … 42.706 ms  ┊ GC (min … max): 0.00% … 66.89%
 Time  (median):     14.304 ms              ┊ GC (median):    7.45%
 Time  (mean ± σ):   14.589 ms ±  2.117 ms  ┊ GC (mean ± σ):  8.64% ±  5.12%

   ▄▇█▇▅▇▇▄▃▂▂                                                 
  ▇███████████▆▄▁▁▁▁▄▄▁▁▄▁▁▁▁▁▁▁▁▁▄▁▁▄▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▇
  13.6 ms      Histogram: log(frequency) by time        22 ms <

 Memory estimate: 45.78 MiB, allocs estimate: 16.

JADekker · Mar 11 '25 15:03

This is why I used d = product_distribution(v_dists) indeed

product_distribution(::Vector) still creates a Distributions.Product. (Otherwise your benchmarks also wouldn't work since you only define methods for ::Product)
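
A hypothetical REPL check of this point (exact types depend on the installed Distributions version):

julia> using Distributions

julia> v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)];

julia> product_distribution(v_dists) isa Distributions.Product  # vector input still goes through the deprecated Product path
true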

I get, using

As I suspected, a few methods can likely be improved for tuples. But note again that the version you're benchmarking against is not correct in general, and splatting a vector should also be avoided.
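
A sketch of why splatting is problematic (illustrative, not part of the original comment): the compiler cannot know the number or types of the splatted arguments from a Vector, so the construction call itself is not inferable even though the resulting tuple-backed distribution is.

julia> using Distributions, Test

julia> v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)];

julia> @inferred product_distribution(v_dists...)  # expected to throw an inference error: argument count and types are runtime information here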

devmotion · Mar 11 '25 15:03

This is why I used d = product_distribution(v_dists) indeed

product_distribution(::Vector) still creates a Distributions.Product.

I see, I didn't notice that! Thanks!

JADekker · Mar 11 '25 15:03

As I suspected, likely a few methods can be improved for tuples. But note again that the version you're benchmarking against is not correct in general.

I agree that my case is specialised for Float64, but getting a large speed-up (and a large reduction in allocations) for a type that is quite common for random variables may still be valuable?

JADekker · Mar 11 '25 15:03

I think on the master branch we could do e.g.

julia> function Distributions._rand!(rng::Distributions.AbstractRNG, d::Distributions.VectorOfUnivariateDistribution, x::AbstractMatrix{<:Real})
           for (xi, dist) in zip(eachrow(x), d.dists)
               rand!(rng, dist, xi)
           end
           return x
       end

julia> rand(d, 10)
3×10 Matrix{Float64}:
 0.397305   0.155416  1.69355  0.284072  1.55921   2.1799     1.26069   2.30843   0.153868  1.69613
 0.337333  -2.05144   1.00901  1.44668   2.23835   0.0380719  0.270388  2.11597  -0.435672  1.40537
 0.211291   0.933384  2.79485  0.677067  0.465326  1.85827    2.64057   5.44137   3.34289   0.287327

julia> @benchmark rand($d, 1_000_000)
BenchmarkTools.Trial: 302 samples with 1 evaluation per sample.
 Range (min … max):  15.541 ms …  27.737 ms  ┊ GC (min … max): 0.00% … 25.28%
 Time  (median):     16.320 ms               ┊ GC (median):    3.48%
 Time  (mean ± σ):   16.558 ms ± 986.985 μs  ┊ GC (mean ± σ):  3.52% ±  1.92%

      ▂▁  ▄▄▇██▇▄▄▄▂▁    ▁
  ▅▅▇███▅▅████████████▆▅▅█▆▇▅▅▆▇▆▆▆▇▅▆▅▆▇▁▁▁▁▅▁▁▅▆▆▁▁▁▁▁▅▁▁▁▁▅ ▇
  15.5 ms       Histogram: log(frequency) by time      19.3 ms <

 Memory estimate: 22.89 MiB, allocs estimate: 7.

This will have to be changed in #1905, of course.

devmotion · Mar 11 '25 15:03