Distributions.jl

Better samplers for Bernoulli(p) and Geometric(1/2) random variables

Open abigail-gentle opened this issue 1 year ago • 3 comments

The current implementation for generating a Bernoulli(p) random variable is rand() <= p. This is totally fine, but there are better ways to handle rational probabilities, and there is a method detailed here which uses less entropy (2 bits in expectation).

For rationals the obvious solution is rand(1:denominator) <= numerator, which is faster on my machine, possibly due to the expense of converting the rational prob to a float for comparison.
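
A minimal sketch of the rational case described above. The helper name `rand_bernoulli_rational` is hypothetical and not part of Distributions.jl; the sketch only assumes that `p` is a `Rational` in [0, 1] and avoids any float conversion.

```julia
using Random

# Hypothetical helper (not a Distributions.jl API): draw one Bernoulli(p)
# sample for a rational p = num/den using only integer arithmetic.
function rand_bernoulli_rational(rng::AbstractRNG, p::Rational{<:Integer})
    0 <= p <= 1 || throw(ArgumentError("p must lie in [0, 1]"))
    # rand(rng, 1:den) is uniform on {1, ..., den}, and exactly num of
    # those values are <= num, so the comparison is true with probability num/den.
    return rand(rng, 1:denominator(p)) <= numerator(p)
end

rng = Xoshiro(1)
samples = [rand_bernoulli_rational(rng, 1//3) for _ in 1:100_000]
println(sum(samples) / length(samples))  # empirical frequency, expected near 1/3
```

Because the comparison is done on integers, the result is exact for any representable rational, whereas rand() <= p first rounds p to a float.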

I also have code for the other sampler, but on my machine it's about 2x as slow as the current implementation. That is still pretty fast: 6.4ns versus 2.7ns for the old method.

In writing that code I needed to sample a Geometric(1/2) distribution, and found that the current implementation is needlessly expensive for this default case. It is faster to generate random bits (one Int, i.e. 8 * sizeof(Int) bits, at a time) until you see a 1. This could go down to generating random Int8s if minimising entropy usage were a concern (but the default sampler generates 64 bits at a time and then truncates anyway).
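
The bit-counting idea can be sketched as follows. This is an illustration under stated assumptions, not the code that was eventually merged: the helper name `rand_geometric_half` is hypothetical, and it assumes the convention that a Geometric sample counts the failures before the first success (support 0, 1, 2, ...).

```julia
using Random

# Hypothetical helper: sample Geometric(1/2) by treating each random bit
# as a fair coin flip and counting the zeros before the first one bit.
function rand_geometric_half(rng::AbstractRNG)
    total = 0
    while true
        bits = rand(rng, UInt64)   # 64 fair coin flips at once
        z = leading_zeros(bits)    # zeros before the first 1 (64 if bits == 0)
        total += z
        z < 64 && return total     # a 1 was found in this chunk
        # Otherwise all 64 flips were zeros (probability 2^-64): draw another chunk.
    end
end

rng = Xoshiro(1)
draws = [rand_geometric_half(rng) for _ in 1:100_000]
println(sum(draws) / length(draws))  # empirical mean, expected near (1 - p) / p = 1
```

In almost every call the loop body runs once, so the cost is one 64-bit draw plus a `leading_zeros` instruction.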

I am interested in implementing these, but want to know whether it is of interest. I also wouldn't be sure how to implement a sampler for a "default" geometric rv that doesn't meddle with the current implementation either.

Posting here to get thoughts. I think the method for rational Bernoullis should absolutely be implemented as a strict improvement, but otherwise I want to know what people think is important.

abigail-gentle avatar Aug 09 '24 08:08 abigail-gentle

Not a maintainer, so don't take my position as canonical.

There's a lot of optimization that can be done with the rand and rand! functions.

Personally, I would also recommend defining explicit rand! methods alongside scalar rand methods. Very few distributions currently do this, but it can yield significant speedups when implemented.

The core reason behind the performance increases is that Random.jl has SIMD implementations for rand! and randn! for both Xoshiro and MersenneTwister. Calling rand(1) 100 times will be significantly slower than calling rand(100) once.
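
To illustrate the scalar-versus-batched distinction, here is a small sketch that thresholds a buffer filled by a single rand! call. It is not a Distributions.jl method; the variable names are illustrative, and no timings are claimed here beyond what is stated above.

```julia
using Random

rng = Xoshiro(42)
p = 0.3
n = 1_000

# Scalar approach: one RNG call per sample.
scalar_draws = [rand(rng) <= p for _ in 1:n]

# Batched approach: fill a preallocated buffer of uniforms with one rand! call,
# then threshold the whole buffer to get Bernoulli(p) samples.
uniforms = Vector{Float64}(undef, n)
rand!(rng, uniforms)
batch_draws = uniforms .<= p

println(sum(scalar_draws) / n, " ", sum(batch_draws) / n)  # both expected near p
```

An explicit rand! method for a distribution could follow the same shape: fill the output in bulk, then transform it in place.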


Don't worry too much if you replace an existing method with something faster that yields different results (but is still within expectations for the distribution). Distributions.jl doesn't guarantee that rand(::AbstractRNG, ::Distribution) will yield the same result across versions with the same seed, and neither does Random.jl.

I'm curious about the other method for sampling now; I'm tempted to try an implementation of my own to see if I can get it faster than what we currently have.

quildtide avatar Aug 09 '24 17:08 quildtide

Really like the suggestion to implement rand!. I have personally had slowdowns from generating lots of Bernoulli samples in my own work. I'll have to look at what I can do to make this work, but it sounds fun.

As for my own implementation, the details are as follows.

Benchmarks for my version on a Mac M1 laptop: [image]. They show comparable performance (min time) but worse average time (it does a lot more operations, which is pretty apparent looking at the built code). This is the code I used:

function my_bernoulli_abbreviated(d::Bernoulli{T})::Bool where T <: Union{Float16, Float32, Float64}

    isone(d.p) && return true
    
    significand_bits::UInt = precision(T) - 1
    max_coin_flips::UInt = Base.exponent_max(T) + significand_bits
    leading_zs::UInt = Base.exponent_max(T) - 1 - Base.unbiased_exponent(d.p)

    first_heads_index::UInt = begin
        coin_flips::UInt = typemax(UInt)
        chunk_size::UInt = 8 * sizeof(UInt)
        num_chunks::UInt = cld(max_coin_flips, chunk_size)
        for i in 1:num_chunks
            geo = leading_zeros(rand(UInt))
            if geo != chunk_size
                coin_flips = (i-1)*chunk_size + geo
                break
            end
        end
        coin_flips
    end
    first_heads_index == typemax(UInt) && return false # Truncated geometric
    first_heads_index < leading_zs && return false # Cannot be 1
    first_heads_index == leading_zs && return !iszero(reinterpret(Unsigned, d.p) & Base.exponent_mask(T)) # implicitly 1, for normal floats
    first_heads_index > leading_zs + significand_bits && return false # beyond the 1's, implicitly 0
    !iszero(reinterpret(Unsigned, d.p) & (1 << (leading_zs + significand_bits - first_heads_index))) # index the i'th value of the significand
end

If you'd like to see a documented version, I can point you to this well-written implementation, which I used as a reference: https://docs.rs/opendp/latest/src/opendp/traits/samplers/bernoulli/mod.rs.html#25-137

I found this post, which promises to be pretty efficient, but I don't understand the method right now: https://discourse.julialang.org/t/faster-bernoulli-sampling/35209

abigail-gentle avatar Aug 12 '24 05:08 abigail-gentle

A special branch for Geometric(0.5) was added in #1934.

devmotion avatar Jan 17 '25 15:01 devmotion