Sampling from Distributions.product_distribution allocates (a lot)
Hi, I noticed that when sampling a matrix of values from a Distributions.product_distribution, a lot of allocations are made (in my example, more allocations than sampled values!) if I call rand(..., N) directly on a product_distribution built from different distributions. However, my naive implementations (rand_quick, rand_quick2) beat the default implementation by roughly a factor of 10 on my device and cut allocations to a constant amount in the matrix case. rand_quick also provides a modest speed-up on my device in the vector case. Am I misusing product_distribution here, or is there room for improvement? See this discourse thread for my initial question that led to this issue: https://discourse.julialang.org/t/product-distribution-allocates-a-lot/126771/2.
using Distributions, BenchmarkTools, Random
Random.seed!(42)
function rand_quick(d::Product)
    N_out = Vector{Float64}(undef, length(d.v))
    for (i, dist) in enumerate(d.v)
        N_out[i] = rand(dist)
    end
    return N_out
end

rand_quick2(d::Product) = [rand(dist) for dist in d.v]

function rand_quick(d::Product, N::Int64)
    N_out = Matrix{Float64}(undef, N, length(d.v))
    for (i, dist) in enumerate(d.v)
        N_out[:, i] .= rand(dist, N)
    end
    return permutedims(N_out)
end

rand_quick2(d::Product, N::Int64) = vcat([rand(dist, N)' for dist in d.v]...)

function run_tests(N::Int64)
    v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)]
    d = product_distribution(v_dists)
    display("Benchmarking single draws")
    display(@benchmark rand($d))
    display(@benchmark rand_quick($d))
    display(@benchmark rand_quick2($d))
    display("Benchmarking samples of size $N")
    display(@benchmark rand($d, $N))
    display(@benchmark rand_quick($d, $N))
    display(@benchmark rand_quick2($d, $N))
    return nothing
end
run_tests(1_000_000)
gives
"Benchmarking single draws"
BenchmarkTools.Trial: 10000 samples with 836 evaluations per sample.
Range (min … max): 143.840 ns … 67.004 μs ┊ GC (min … max): 0.00% … 99.73%
Time (median): 158.044 ns ┊ GC (median): 0.00%
Time (mean ± σ): 168.521 ns ± 669.044 ns ┊ GC (mean ± σ): 4.36% ± 2.43%
▄▇▇█▇▃▂▁ ▁▁▃▄▄
▁▂▃▆▇███████████████▆▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁ ▃
144 ns Histogram: frequency by time 210 ns <
Memory estimate: 128 bytes, allocs estimate: 4.
BenchmarkTools.Trial: 10000 samples with 858 evaluations per sample.
Range (min … max): 136.995 ns … 69.691 μs ┊ GC (min … max): 0.00% … 99.72%
Time (median): 151.759 ns ┊ GC (median): 0.00%
Time (mean ± σ): 170.850 ns ± 724.923 ns ┊ GC (mean ± σ): 4.52% ± 2.61%
█▆▄▁ ▂▆▁
▁▂█████▇███▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
137 ns Histogram: frequency by time 253 ns <
Memory estimate: 128 bytes, allocs estimate: 4.
BenchmarkTools.Trial: 10000 samples with 545 evaluations per sample.
Range (min … max): 209.939 ns … 96.199 μs ┊ GC (min … max): 0.00% … 99.72%
Time (median): 224.007 ns ┊ GC (median): 0.00%
Time (mean ± σ): 238.015 ns ± 960.152 ns ┊ GC (mean ± σ): 4.21% ± 1.78%
▁▄▃▅▇▇▇▇███▇▆▆▆▆▄▃▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁ ▁ ▃
▆████████████████████████████████████████████▇▇█▇▇▇▇▇█▇▆▆▇▆▇▆ █
210 ns Histogram: log(frequency) by time 290 ns <
Memory estimate: 160 bytes, allocs estimate: 6.
"Benchmarking samples of size 1000000"
BenchmarkTools.Trial: 31 samples with 1 evaluation per sample.
Range (min … max): 149.399 ms … 194.403 ms ┊ GC (min … max): 1.35% … 25.34%
Time (median): 156.147 ms ┊ GC (median): 4.97%
Time (mean ± σ): 161.630 ms ± 9.801 ms ┊ GC (mean ± σ): 9.54% ± 5.40%
▁▄ ▄▄▁ ▄ ▁▁ █
▆▁██▆▁███▆▁▁▁▁▁▁▁▁▁▁▁▆▁█▆██▆█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆ ▁
149 ms Histogram: frequency by time 194 ms <
Memory estimate: 114.43 MiB, allocs estimate: 5999491.
BenchmarkTools.Trial: 323 samples with 1 evaluation per sample.
Range (min … max): 14.128 ms … 63.315 ms ┊ GC (min … max): 4.72% … 77.16%
Time (median): 15.284 ms ┊ GC (median): 6.86%
Time (mean ± σ): 15.520 ms ± 2.754 ms ┊ GC (mean ± σ): 7.90% ± 4.68%
▁▁ ▁▄▃▃▄▄▄▄▄▇▃█▁▅▄▁▄ ▃▆ ▁
▃▁▁▅▃▇██▆████████████████████▇█▅▆▅█▄▃▆▆▃▁▃▄▅▁▁▃▆▃▄▃▄▁▃▁▃▃▃▃ ▅
14.1 ms Histogram: frequency by time 17.4 ms <
Memory estimate: 68.67 MiB, allocs estimate: 22.
BenchmarkTools.Trial: 345 samples with 1 evaluation per sample.
Range (min … max): 13.138 ms … 40.207 ms ┊ GC (min … max): 0.00% … 66.52%
Time (median): 14.345 ms ┊ GC (median): 6.93%
Time (mean ± σ): 14.497 ms ± 1.854 ms ┊ GC (mean ± σ): 7.89% ± 4.86%
▂ ▇█▄▄ ▄▁▇▅▆▇▄▂▂▄ ▂ ▁
▃▃▁▁▁▁▁▁▁▃▁▃▅▃▄███▇▇███████████████▆█▄█▆▅▃▅▄▃▃▃▃▃▃▃▃▃▁▁▁▁▁▃ ▄
13.1 ms Histogram: frequency by time 15.7 ms <
Memory estimate: 45.78 MiB, allocs estimate: 20.
A further point (that is also raised in the thread linked above) is that using product_distribution with different types of distributions introduces type instability. Is there a way in Distributions.jl to avoid this?
> However, my naive implementations (rand_quick, rand_quick2) beat the default implementation by roughly a factor 10 on my device and cut allocations to a constant amount in the matrix case.
In general, the output type may not consist of Float64s. Note also that Distributions.Product is deprecated, which probably explains why it hasn't been developed further lately.
> Is there a way in Distributions.jl to avoid this?
~~No, not right now. Distributions.ProductDistribution (and the deprecated Distributions.Product) operates on vectors of distributions; by design this creates type instabilities for vectors containing different types of distributions.~~
In principle, product_distribution supports tuples; you can e.g. define
d = product_distribution(Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0))
This can also be inferred. However, probably not every method is optimized for tuples.
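The difference can be checked directly with Test.@inferred (a minimal sketch, based on the inference claim above; the exact inferred return types depend on the Distributions version):

using Distributions, Test

v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)]

# Vector construction: the container eltype is an abstract Distribution
# type, so the compiler cannot infer a concrete sample type from it.
d_vec = product_distribution(v_dists)

# Tuple construction: each component's concrete type is part of the
# product's type, so inference can see through rand.
d_tup = product_distribution(Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0))
@inferred rand(d_tup)  # should pass for the tuple version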
The tuple version seems to be faster (as expected):
julia> using Distributions, BenchmarkTools
julia> d = product_distribution(Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0));
julia> rand(d)
3-element Vector{Float64}:
1.4515762558532885
-0.8798585959543993
0.4173652746533735
julia> @benchmark rand($d)
BenchmarkTools.Trial: 10000 samples with 986 evaluations per sample.
Range (min … max): 53.331 ns … 2.745 μs ┊ GC (min … max): 0.00% … 97.27%
Time (median): 56.922 ns ┊ GC (median): 0.00%
Time (mean ± σ): 68.167 ns ± 112.689 ns ┊ GC (mean ± σ): 15.87% ± 9.21%
▁█▂ ▇▁
▃███▄▄▇▇▅███▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂ ▃
53.3 ns Histogram: frequency by time 79.4 ns <
Memory estimate: 240 bytes, allocs estimate: 6.
> Note also that Distributions.Product is deprecated, which probably explains why it hasn't been developed further lately.
This is why I used d = product_distribution(v_dists) indeed
> The tuple version seems to be faster (as expected):
This is for 1 draw. With the tuple version and N = 1,000,000 draws, I get using
function rand_quick(d, N::Int64)
    N_out = Matrix{Float64}(undef, N, length(d.dists))
    for (i, dist) in enumerate(d.dists)
        N_out[:, i] .= rand(dist, N)
    end
    return permutedims(N_out)
end

rand_quick2(d, N::Int64) = vcat([rand(dist, N)' for dist in d.dists]...)
that
v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)]
d_alt = product_distribution(v_dists...)
display(@benchmark rand($d_alt, $N))
display(@benchmark rand_quick($d_alt, $N))
display(@benchmark rand_quick2($d_alt, $N))
produces
BenchmarkTools.Trial: 69 samples with 1 evaluation per sample.
Range (min … max): 53.807 ms … 129.197 ms ┊ GC (min … max): 6.04% … 34.83%
Time (median): 65.456 ms ┊ GC (median): 7.91%
Time (mean ± σ): 72.510 ms ± 17.873 ms ┊ GC (mean ± σ): 26.13% ± 16.29%
█
█▄▃▅▄▄▃▃▄▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▇▄▄▅▄▄▁▁▁▄▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁
53.8 ms Histogram: frequency by time 116 ms <
Memory estimate: 175.48 MiB, allocs estimate: 4000002.
BenchmarkTools.Trial: 318 samples with 1 evaluation per sample.
Range (min … max): 14.147 ms … 65.672 ms ┊ GC (min … max): 3.84% … 77.85%
Time (median): 15.344 ms ┊ GC (median): 7.01%
Time (mean ± σ): 15.716 ms ± 3.042 ms ┊ GC (mean ± σ): 8.29% ± 4.91%
▄▄▃▅▇▂▄▇▇▆█▁▃▁
▃▄▆██████████████▇▇▆▄▇▄▃▃▃▄▁▄▃▁▁▁▃▁▃▃▃▃▁▁▁▁▃▃▁▁▃▁▁▁▁▁▁▁▁▃▁▄ ▄
14.1 ms Histogram: frequency by time 20.4 ms <
Memory estimate: 68.66 MiB, allocs estimate: 14.
BenchmarkTools.Trial: 343 samples with 1 evaluation per sample.
Range (min … max): 13.553 ms … 42.706 ms ┊ GC (min … max): 0.00% … 66.89%
Time (median): 14.304 ms ┊ GC (median): 7.45%
Time (mean ± σ): 14.589 ms ± 2.117 ms ┊ GC (mean ± σ): 8.64% ± 5.12%
▄▇█▇▅▇▇▄▃▂▂
▇███████████▆▄▁▁▁▁▄▄▁▁▄▁▁▁▁▁▁▁▁▁▄▁▁▄▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▇
13.6 ms Histogram: log(frequency) by time 22 ms <
Memory estimate: 45.78 MiB, allocs estimate: 16.
> This is why I used d = product_distribution(v_dists) indeed
product_distribution(::Vector) still creates a Distributions.Product. (Otherwise your benchmarks also wouldn't work, since you only define methods for ::Product.)
> I get using
As I suspected, likely a few methods can be improved for tuples. But note again that the version you're benchmarking against is not correct in general. Splatting a vector should also be avoided in general.
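The splatting concern can be illustrated as follows (a hypothetical sketch, not from the thread; d_splat/d_tuple are illustrative names):

using Distributions

v_dists = [Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0)]

# Splatting a Vector: the number and types of the arguments are unknown
# at compile time, so the call site is type-unstable and pays for
# dynamic dispatch on every construction.
d_splat = product_distribution(v_dists...)

# Preferred: pass the components explicitly (or as a tuple), so their
# number and concrete types are static.
d_tuple = product_distribution(Exponential(1.0), Normal(0.0, 1.0), LogNormal(0.0, 1.0))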
> This is why I used d = product_distribution(v_dists) indeed
>
> product_distribution(::Vector) still creates a Distributions.Product.
I see, I didn't notice that! Thanks!
> As I suspected, likely a few methods can be improved for tuples. But note again that the version you're benchmarking against is not correct in general.
I agree that my case is specialised for Float64, but getting a large speed-up (and large reduction of allocations) on a type that is quite common for random variables may be valuable?
I think on the master branch we could do e.g.
julia> function Distributions._rand!(rng::Distributions.AbstractRNG, d::Distributions.VectorOfUnivariateDistribution, x::AbstractMatrix{<:Real})
           for (xi, dist) in zip(eachrow(x), d.dists)
               rand!(rng, dist, xi)
           end
           return x
       end
julia> rand(d, 10)
3×10 Matrix{Float64}:
0.397305 0.155416 1.69355 0.284072 1.55921 2.1799 1.26069 2.30843 0.153868 1.69613
0.337333 -2.05144 1.00901 1.44668 2.23835 0.0380719 0.270388 2.11597 -0.435672 1.40537
0.211291 0.933384 2.79485 0.677067 0.465326 1.85827 2.64057 5.44137 3.34289 0.287327
julia> @benchmark rand($d, 1_000_000)
BenchmarkTools.Trial: 302 samples with 1 evaluation per sample.
Range (min … max): 15.541 ms … 27.737 ms ┊ GC (min … max): 0.00% … 25.28%
Time (median): 16.320 ms ┊ GC (median): 3.48%
Time (mean ± σ): 16.558 ms ± 986.985 μs ┊ GC (mean ± σ): 3.52% ± 1.92%
▂▁ ▄▄▇██▇▄▄▄▂▁ ▁
▅▅▇███▅▅████████████▆▅▅█▆▇▅▅▆▇▆▆▆▇▅▆▅▆▇▁▁▁▁▅▁▁▅▆▆▁▁▁▁▁▅▁▁▁▁▅ ▇
15.5 ms Histogram: log(frequency) by time 19.3 ms <
Memory estimate: 22.89 MiB, allocs estimate: 7.
This will have to be changed in #1905, of course.