StochasticAD.jl
Make `propagate` more like a monadic bind by supporting functions that create stochastic triples
Cross-reference: #128. Example from @GuusAvis:
using StochasticAD
using Distributions
"""
    f(value_1, value_2, rand_var)

Take one random step on the pair `(value_1, value_2)`.

Draw once from `rand_var` and add the draw to the smaller of the two values.
When `value_1 < value_2` the draw is added to `value_1`. Otherwise, including
ties, it is added to `value_2`. Return the updated pair as a 2-tuple.
"""
function f(value_1, value_2, rand_var)
    # Guard clause: `value_1` is not strictly smaller, so `value_2` receives the draw.
    if !(value_1 < value_2)
        return (value_1, value_2 + rand(rand_var))
    end
    return (value_1 + rand(rand_var), value_2)
end
# Call `f` through `StochasticAD.propagate`, passing `value_1` and `value_2` as the
# arguments handed to `propagate`. `rand_var` is captured by the closure instead of
# being passed to `propagate`.
# NOTE(review): inside `g`, `rand_var` is `Bernoulli(p)`. If `p` is itself a stochastic
# triple, the draw made inside the closure can create new triples. Supporting that is
# the subject of this issue — confirm against the `propagate` implementation.
function propagate_f(value_1, value_2, rand_var)
    step_fn = (v1, v2) -> f(v1, v2, rand_var)
    return StochasticAD.propagate(step_fn, value_1, value_2)
end
# Route calls where either of the first two arguments is a `StochasticTriple` through
# `propagate_f`. The method taking two triples exists so that calls with both arguments
# being triples are not ambiguous between the two single-triple methods.
function f(value_1::StochasticTriple, value_2, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end
function f(value_1, value_2::StochasticTriple, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end
function f(value_1::StochasticTriple, value_2::StochasticTriple, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end
"""
    g(p)

Start from the pair `(0, 2)` and apply `f` ten times with `rand_var = Bernoulli(p)`.
Return the final pair `(value_1, value_2)` as a tuple.
"""
function g(p)
    rand_var = Bernoulli(p)
    # Thread the pair through as a single tuple; `f` always returns a 2-tuple.
    state = (0, 2)
    for _ in 1:10
        state = f(state..., rand_var)
    end
    return state
end
# Reference value for both estimates below: every call to `f` adds exactly one draw of
# `rand(Bernoulli(p))` to one of the two values. Hence sum(g(p)) == 2 + (sum of 10
# Bernoulli(p) draws), whose expectation is 2 + 10p. The exact derivative is 10.
@show g(0.5)
# Central finite difference around p = 0.5, so it targets the same point as
# `derivative_estimate` below. The original forward difference 0.5 -> 0.6 targeted ~0.55.
# The samples are independent, so expect Monte Carlo noise around 10.
# (The forward-difference version reported 9.59.)
@show mean((sum(g(0.55)) - sum(g(0.45))) / 0.1 for _ in 1:1000)
# Originally reported 8.84; the exact value is 10.
@show mean(derivative_estimate(p -> sum(g(p)), 0.5) for _ in 1:100)
@GuusAvis, let me know if you run into any issues. If things work out, adding the above as a test to `triples.jl`
would be most welcome :)