Distributions.jl
Fix Dirichlet rand overflows #1702
Closes #1702
Core Issues
`rand(d::Dirichlet)` draws one value from `Gamma(d.α[i])` for each component `i` and writes it to `x`. It then rescales the result by `inv(sum(x))`. When that inverse overflows to `Inf`, we hit one of two failure modes:
- When all `x_i == 0`, we get `Inf * 0 = NaN`.
- When some `x_i != 0` but all of them are deeply subnormal, `inv(sum(x))` still overflows, so the nonzero entries become `Inf` (and any entries that are exactly 0 become `NaN`).
For case 2, on Julia 1.11.0-rc1 on Windows, for example:
```julia
julia> rand(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]))
3-element Vector{Float64}:
Inf
Inf
NaN
```
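Both failure modes can be reproduced without any random sampling by feeding hand-picked stand-ins for the Gamma draws into the same rescaling step. This is a minimal sketch: `naive_normalize` is a hypothetical helper that mirrors the rescaling by `inv(sum(x))` described above, not the actual Distributions.jl implementation.

```julia
# Hypothetical stand-in for the final step of rand(d::Dirichlet):
# x holds the raw Gamma draws, which are rescaled by inv(sum(x)).
naive_normalize(x::Vector{Float64}) = x .* inv(sum(x))

# inv(s) overflows roughly whenever s < inv(floatmax(Float64)) ≈ 5.6e-309,
# and that threshold is itself subnormal (below floatmin(Float64) ≈ 2.2e-308).
threshold = inv(floatmax(Float64))

# Case 1: every draw underflowed to 0, so inv(0.0) == Inf and Inf * 0.0 == NaN.
case1 = naive_normalize([0.0, 0.0, 0.0])

# Case 2: nonzero but deeply subnormal draws. nextfloat(0.0) ≈ 4.9e-324 is the
# smallest positive Float64, and inv(2 * nextfloat(0.0)) overflows to Inf.
tiny = nextfloat(0.0)
case2 = naive_normalize([tiny, tiny, 0.0])   # [Inf, Inf, NaN]
```

Entries that are exactly zero still come out as `NaN` in case 2, which matches the trailing `NaN` in the REPL output.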
Fixing Case 1
If case 1 occurs, the best option from a runtime perspective is probably to return a one-hot vector `x = e_i`, with `i` drawn from a categorical distribution with probabilities `α ./ sum(α)`, which has the same mean as the Dirichlet distribution. This is the limiting behavior of the Dirichlet distribution as its parameters shrink toward zero, and my reasoning for why it is "safe enough" is:
- If all-zero draws are rare, this has little impact on the overall sample.
- If all-zero draws are common, rejecting them and drawing again will probably lead to a near-infinite rejection loop. On the other hand, in that regime we are close enough to the limiting behavior that floating-point errors are probably hurting us more than adopting the limit would.
- While this should theoretically give an incorrect variance, testing shows that the variance stays within a tolerance of 0.01 of the true value.
Another option would be to reject all-zero samples up to some maximum number of attempts before failing, but I think this is probably a waste of time for little gain in accuracy.
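The case-1 fallback can be sketched as follows, assuming it returns a one-hot vector whose index is drawn with probabilities `α ./ sum(α)`. `onehot_fallback` is a hypothetical name, the categorical draw is done by a simple cumulative-sum search using only the `Random` standard library, and this is not the PR's actual code.

```julia
using Random

# Hypothetical helper (not the PR's code): the case-1 fallback.
# Returns the one-hot vector e_i, with i drawn from a categorical distribution
# with probabilities α[i] / sum(α), so it has the same mean as Dirichlet(α).
function onehot_fallback(rng::AbstractRNG, α::Vector{T}) where {T<:AbstractFloat}
    x = zeros(T, length(α))
    u = rand(rng, T) * sum(α)      # uniform point in [0, sum(α))
    acc = zero(T)
    for i in eachindex(α)
        acc += α[i]                 # running cumulative sum of the weights
        if u < acc
            x[i] = one(T)
            return x
        end
    end
    # Rounding can leave u >= the final cumulative sum; use the last positive weight.
    x[findlast(>(zero(T)), α)] = one(T)
    return x
end

rng = Xoshiro(1)
α = [4.5e-5, 4.5e-5, 8e-5]
draws = [onehot_fallback(rng, α) for _ in 1:100_000]
empirical_mean = sum(draws) ./ length(draws)   # close to α ./ sum(α) ≈ [0.265, 0.265, 0.471]
```

On the variance: with `p_i = α_i / α_0` and `α_0 = sum(α)`, the Dirichlet variance of component `i` is `p_i(1 - p_i) / (α_0 + 1)`, while the one-hot fallback has variance `p_i(1 - p_i)`. The two differ by a factor of `α_0 + 1`, which is negligible when `α_0` is tiny, consistent with the small variance error reported in the list.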
Fixing Case 2
We rescale all values by multiplying them by `floatmax()` so that `inv` doesn't overflow. This should work consistently for any float type `T` where `floatmax(T) * nextfloat(zero(T))` (the largest finite value times the smallest positive subnormal) exceeds `floatmin(T)` by at least about an order of magnitude, which I think holds for any non-exotic float type. I originally thought it would be enough to just set the largest value to 1, but the current code can draw several subnormal values before normalization, and the method I adopted maintains the ratios between them.
Currently:
```julia
julia> rand(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]))
3-element Vector{Float64}:
Inf
Inf
NaN
```
After this patch:
```julia
julia> rand(Xoshiro(123322), Dirichlet([4.5e-5, 4.5e-5, 8e-5]))
3-element Vector{Float64}:
0.625061099164708
0.37493890083529186
0.0
```
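The case-2 rescaling can be sketched as below, assuming the guard only triggers when `inv(sum(x))` actually overflows. `rescale_and_normalize!` is a hypothetical helper, not the PR's code; the block also evaluates the `floatmax(T) * nextfloat(zero(T))` versus `floatmin(T)` margin for the three built-in IEEE float types.

```julia
# Hypothetical sketch of the case-2 guard (the PR's exact code may differ):
# if inv(sum(x)) would overflow, first multiply every draw by floatmax(T).
function rescale_and_normalize!(x::Vector{T}) where {T<:AbstractFloat}
    s = sum(x)
    iszero(s) && throw(ArgumentError("all draws are zero; case 1 needs its own fallback"))
    if isinf(inv(s))
        # Here every x[i] <= s < ~inv(floatmax(T)), so x[i] * floatmax(T) < ~1:
        # the scaled values cannot overflow, and their ratios are kept up to rounding.
        x .*= floatmax(T)
        s = sum(x)
    end
    x .*= inv(s)
    return x
end

# Ratio floatmax(T) * nextfloat(zero(T)) / floatmin(T) for each standard float type;
# the condition in the text asks for this to be at least about 10.
margins = [floatmax(T) * nextfloat(zero(T)) / floatmin(T) for T in (Float16, Float32, Float64)]

tiny = nextfloat(0.0)
x = rescale_and_normalize!([tiny, tiny, 0.0])   # ≈ [0.5, 0.5, 0.0] instead of [Inf, Inf, NaN]
```

For `Float16` the margin is only about 64, so it is the tightest of the three standard types, while `Float32` and `Float64` clear the condition by many orders of magnitude.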
Subnormal Parameters
While testing, I realized that my original fix for case 1 would break when all of the parameters themselves are deeply subnormal, e.g. `Dirichlet([5e-321, 1e-321, 4e-321])`. Since the Dirichlet distribution is fairly common in areas such as Bayesian inference, I thought it would be worth supporting these cases too.
Note that `mean`, `var`, etc. currently break on these distributions with deeply subnormal parameters, but fixing that felt out of scope for this pull request. Fixing `mean` would be simple, but the change could be rather chunky. I am less sure about `var` and the others.
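A minimal illustration of why deeply subnormal parameters are fragile, assuming a normalization written as multiplication by an inverse: with the example parameters, `sum(α)` is itself subnormal (about 1e-320), so `inv(sum(α))` overflows, while dividing by the sum directly stays finite. This only shows the arithmetic; it is not a claim about how `mean` or the PR's fallback are implemented.

```julia
# Example parameters from the text: every entry is deeply subnormal.
α = [5e-321, 1e-321, 4e-321]
all_subnormal = all(issubnormal, α) && issubnormal(sum(α))   # the sum (≈ 1e-320) is subnormal too

# Multiplying by the inverse overflows: inv(≈ 1e-320) == Inf, so every entry is Inf.
via_inv = α .* inv(sum(α))

# Dividing each entry by the sum never forms the overflowing inverse.
# Subnormals carry few significant bits, so this is only approximately [0.5, 0.1, 0.4].
via_div = α ./ sum(α)
```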