StochasticAD.jl
PMF of Bernoullis
Hey!
I am trying to do variational inference (VI) with models involving discrete random variables (RVs). For that purpose, it would be quite handy to get derivative estimators for the PMFs of Bernoulli (and other discrete) RVs. Consider the following toy example:
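```julia
using StochasticAD, Distributions

# Sample x ~ Bernoulli(p), then evaluate the PMF at x via Distributions.pdf
function func(p)
    x = rand(Bernoulli(p))
    pdf(Bernoulli(p), x)
end

# Same quantity written arithmetically: p^x * (1 - p)^(1 - x)
function func_alt(p)
    x = rand(Bernoulli(p))
    p^x * (1 - p)^(1 - x)
end
```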
It seems that I can only get derivative estimators for `func_alt`. When I try to propagate derivative information through `func`, it appears to fail because of the way `pdf(::Bernoulli, ::Bool/Real)` is implemented.
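For reference, both functions return the Bernoulli PMF evaluated at the sampled $x$, so the expectation whose derivative is being estimated has a simple closed form, which is handy for checking estimates:

$$
\mathbb{E}_{x\sim\mathrm{Bernoulli}(p)}\big[p^{x}(1-p)^{1-x}\big] = p\cdot p + (1-p)\cdot(1-p) = p^{2} + (1-p)^{2},
\qquad
\frac{d}{dp}\big[p^{2} + (1-p)^{2}\big] = 2p - 2(1-p) = 4p - 2.
$$

At $p = 0.7$ this gives $0.8$, which the average of many single-sample derivative estimates should approach.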
Now my questions:
- Am I using StochasticAD incorrectly?
- If not, would it be easy to accommodate propagation of stochastic triples through `pdf(::Bernoulli, ::Bool/Real)` (and perhaps the equivalent for other distributions with discrete support, where I assume similar problems would arise)?
Thanks! Flemming
The issue is that stochastic triples unfortunately cannot propagate through the ternary operator in Distributions.jl's implementation of the Bernoulli PMF. You could fix this by overloading `Distributions.pdf` to catch stochastic triple inputs and feed them into the experimental (undocumented) `StochasticAD.propagate` interface:
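```julia
using StochasticAD, Distributions
# Make Bernoulli's fields (here p) visible to Functors.jl, so that propagate can see a triple stored inside d
using Functors; @functor Bernoulli
# Register an overload of pdf for stochastic-triple inputs
Distributions.pdf(d::Bernoulli, x::StochasticAD.StochasticTriple) = StochasticAD.propagate(pdf, d, x; keep_deltas = Val{true}())
# func as defined in the question
derivative_estimate(func, 0.7)
```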
This should work on the about-to-be-released 0.1.14. (Performance may not be ideal; `StochasticAD.propagate` is still experimental functionality.)
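A minimal sanity-check sketch, assuming StochasticAD ≥ 0.1.14 together with Functors.jl and the Statistics standard library: average many single-sample estimates for both `func` and `func_alt` and compare them with the closed-form derivative. Since both functions return $p^x(1-p)^{1-x}$ for the sampled $x$, the expectation is $p^2 + (1-p)^2$ and its derivative is $4p - 2$. The sample count `n` is an arbitrary choice.

```julia
using StochasticAD, Distributions, Statistics
using Functors

# Overload from the reply (relies on the experimental `propagate` interface).
@functor Bernoulli
Distributions.pdf(d::Bernoulli, x::StochasticAD.StochasticTriple) =
    StochasticAD.propagate(pdf, d, x; keep_deltas = Val{true}())

# Toy functions from the question.
function func(p)
    x = rand(Bernoulli(p))
    pdf(Bernoulli(p), x)
end

function func_alt(p)
    x = rand(Bernoulli(p))
    p^x * (1 - p)^(1 - x)
end

p = 0.7
n = 10_000

# Each derivative_estimate call returns one stochastic sample of d/dp E[f(p)];
# averaging many samples reduces the Monte Carlo error.
est_func = mean(derivative_estimate(func, p) for _ in 1:n)
est_alt = mean(derivative_estimate(func_alt, p) for _ in 1:n)

# Closed form: E[f] = p^2 + (1 - p)^2, so d/dp E[f] = 4p - 2.
exact = 4p - 2
println((est_func, est_alt, exact))
```

If the overload behaves as intended, both averages should land close to `exact`; the remaining gap is Monte Carlo noise, which shrinks roughly like 1/√n.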
Let me know if you have any questions! In any case, let's leave this issue open until this works out of the box.
(Edit: added `keep_deltas = Val{true}()` to the `propagate` call; the previous version was not correct.)