Categorical distribution
Partially solves #145. I have not added the alternative (log-p) parametrizations yet.
I think this one is a typical example of the limits of our approach, which was to default to Distributions.jl for sampling and other things whenever possible. Indeed, their implementation of `Categorical` relies on `DiscreteNonParametric`, which is a nightmare in terms of efficiency (see here for the worst part).
In my opinion, this boils down to their library being "safe" whereas we over at MeasureTheory.jl choose to trust that the user will avoid nonsensical inputs.
Codecov Report
Merging #148 (2f7e03c) into master (419d2f2) will decrease coverage by 8.79%. The diff coverage is 0.00%.
```diff
@@            Coverage Diff             @@
##           master     #148      +/-   ##
==========================================
- Coverage   42.06%   33.27%    -8.80%
==========================================
  Files          50       30      -20
  Lines        1034      568     -466
==========================================
- Hits          435      189     -246
+ Misses        599      379     -220
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/parameterized/categorical.jl | 0.00% <0.00%> (ø) | |
| src/distproxy.jl | 50.00% <0.00%> (-50.00%) | :arrow_down: |
| src/parameterized/normal.jl | 17.24% <0.00%> (-40.46%) | :arrow_down: |
| src/parameterized/cauchy.jl | 14.28% <0.00%> (-28.58%) | :arrow_down: |
| src/parameterized/studentt.jl | 33.33% <0.00%> (-26.67%) | :arrow_down: |
| src/parameterized/laplace.jl | 9.09% <0.00%> (-24.25%) | :arrow_down: |
| src/combinators/product.jl | 0.00% <0.00%> (-22.48%) | :arrow_down: |
| src/combinators/weighted.jl | 0.00% <0.00%> (-20.69%) | :arrow_down: |
| src/parameterized/gumbel.jl | 12.50% <0.00%> (-16.08%) | :arrow_down: |
| src/const.jl | 26.66% <0.00%> (-10.84%) | :arrow_down: |
| ... and 26 more | | |
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Thanks @gdalle for pushing on this; it's great to have it moving forward. Yes, I think we can get things much more efficient, in a few ways. First, in the logp parameterization:
- the `logdensity` will just be `getindex`
- we can use the Gumbel-max trick for sampling
- we can avoid normalization, leaving it to depend on `logsumexp(logp)`
Then, in the case where the user passes something already normalized, `logsumexp(logp)` is zero, so the correction drops out.
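A minimal sketch of how that could look, assuming a hypothetical `LogCategorical` struct and `logsumexp` from LogExpFunctions.jl (names are illustrative, not the final API):

```julia
using LogExpFunctions: logsumexp

# Hypothetical log-p parameterized categorical; logp need not be normalized.
struct LogCategorical{T<:AbstractVector}
    logp::T
end

# The logdensity is just indexing, plus a logsumexp correction that
# vanishes (== 0) whenever logp is already normalized.
logdensity(d::LogCategorical, i::Integer) = d.logp[i] - logsumexp(d.logp)
```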
Clearly we also need the unlogged parameterization, and I think we can do much better in that case as well.
It's very strange to me that they `sortperm(vs)`. If anything, I would expect sorting the probabilities in reverse order, so that when you walk down them to sample you can stop earlier. Anyway, I think we can avoid this, or when we do have to sort we can lean on PermutedArrays:
https://github.com/cscherrer/PermutedArrays.jl
I have a whole plan for using this to build a better `DiscreteNonParametric` based around a `PriorityQueue`, but implemented more efficiently than that; it will use Dictionaries.jl. That's a whole separate issue, so we can discuss it separately if you like :)
Also, I should maybe point out that the whole `distproxy` thing isn't core to our approach. It's a stand-in, and where it works well we can keep it. But the great thing is, we can break away as needed to get better flexibility and/or performance.
I think the `sortperm` is useful for quick CDF computations, because then you only need a `searchsortedfirst(vs, x)` to know at which step of the discrete distribution you are.
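A small sketch of that idea, assuming a sorted support `vs` with precomputed cumulative probabilities (names are hypothetical; `searchsortedlast` is used here so that query points falling between atoms are handled too):

```julia
# Hypothetical sketch: CDF lookup on a sorted discrete support.
function discrete_cdf(vs::AbstractVector, cumprobs::AbstractVector, x)
    i = searchsortedlast(vs, x)  # index of the last support point ≤ x
    return i == 0 ? zero(eltype(cumprobs)) : cumprobs[i]
end

vs = [1, 3, 7]
cumprobs = cumsum([0.2, 0.5, 0.3])
discrete_cdf(vs, cumprobs, 3)    # 0.7
discrete_cdf(vs, cumprobs, 4)    # still 0.7: falls between atoms
```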
Hi @gdalle, how's this going? Let me know if you get stuck or have questions, or when it's ready for review :)
Hey @cscherrer,
Just added the logp parametrization, and realized the `sortperm` may not be a worry since the values are a range type: judging by a quick `@benchmark`, sorting `p` doesn't seem to reallocate.
However, in the log parametrization, the sampling is very inefficient since it requires computing the exponential of `logp`. Is that what you wanted me to use the Gumbel trick for? I don't know what it is just yet, but I will look it up.
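For reference, allocation claims like that can be checked with BenchmarkTools.jl; a sketch (the sizes are illustrative, not the PR's actual benchmark):

```julia
using BenchmarkTools

p = rand(100)
p ./= sum(p)                # normalize into a probability vector

@benchmark sort!(copy($p))  # in-place sort; copy isolates the input
@benchmark sortperm($p)     # permutation vector instead of sorting
```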
If you have a vector `logp` of log-probabilities, it's just

```julia
draw_categorical(logp) = argmax(logp .- log.(-log.(rand(length(logp)))))
```
The second term is (the negation of) a vector of standard Gumbel samples, so subtracting it adds Gumbel noise to `logp`.
The trade-offs are:
- you pay two logs instead of one exp
- but after that the only dependency is the `argmax`, so it's a `reduce` and should scale well
- staying in log space avoids underflow
- maybe the biggest thing: there's no need to normalize anything (the sketch below checks this empirically)
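A quick empirical check of that last point, assuming the `draw_categorical` one-liner above (the probabilities are illustrative):

```julia
draw_categorical(logp) = argmax(logp .- log.(-log.(rand(length(logp)))))

# Deliberately unnormalized log-probabilities: shifted by a constant.
logp = log.([0.1, 0.2, 0.7]) .+ 3.0

# Empirical frequencies should still match [0.1, 0.2, 0.7].
counts = zeros(Int, length(logp))
for _ in 1:100_000
    counts[draw_categorical(logp)] += 1
end
counts ./ sum(counts)
```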
Hey! Is this PR blocked, or is there something manageable that still needs to be tackled?
Honestly, I haven't touched this in years, so I'd better let @cscherrer answer.
Hi! Any news on this?
Not on my end: my work no longer requires MeasureTheory, so I don't think I'll pursue this further.
Feel free to take it over, though!