StochasticAD.jl icon indicating copy to clipboard operation
StochasticAD.jl copied to clipboard

Make `propagate` more like a monadic bind by supporting stochastic-triple-creating functions

Open gaurav-arya opened this issue 6 months ago • 6 comments

x-ref #128. @GuusAvis example:

using StochasticAD
using Distributions

"""
    f(value_1, value_2, rand_var)

Return the pair `(value_1, value_2)` after adding one draw `rand(rand_var)` to
exactly one of the two entries.

The draw is added to `value_1` when `value_1 < value_2` holds; otherwise
(including the case of equal values) it is added to `value_2`.
"""
function f(value_1, value_2, rand_var)
    # Guard clause: unless value_1 is strictly smaller, value_2 receives the draw.
    value_1 < value_2 || return (value_1, value_2 + rand(rand_var))
    return (value_1 + rand(rand_var), value_2)
end

"""
    propagate_f(value_1, value_2, rand_var)

Hand `value_1` and `value_2` to `StochasticAD.propagate` together with a
two-argument closure that calls `f` with `rand_var` held fixed.
"""
function propagate_f(value_1, value_2, rand_var)
    # The do-block becomes the first (function) argument of `propagate`.
    return StochasticAD.propagate(value_1, value_2) do v1, v2
        f(v1, v2, rand_var)
    end
end

# Whenever at least one of the two values is a `StochasticTriple`, route the
# call through `propagate_f` instead of the generic `f` method.
function f(value_1::StochasticTriple, value_2, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end

function f(value_1, value_2::StochasticTriple, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end

# Covers the case where both values are triples; without it the two methods
# above would be equally specific for such a call.
function f(value_1::StochasticTriple, value_2::StochasticTriple, rand_var)
    return propagate_f(value_1, value_2, rand_var)
end

"""
    g(p)

Start from the pair `(0, 2)` and apply `f` ten times in a row, each time with
the distribution `Bernoulli(p)`. Return the resulting pair.
"""
function g(p)
    dist = Bernoulli(p)
    a, b = 0, 2
    # The loop index is unused; only the number of repetitions matters.
    for _ in 1:10
        a, b = f(a, b, dist)
    end
    return (a, b)
end

# Print one sample of the pair returned by g at p = 0.5.
@show g(0.5)
# Finite-difference quotient of sum(g(p)) between p = 0.5 and p = 0.6 (step 0.1),
# averaged over 1000 independently sampled quotients.
@show mean((sum(g(0.6)) - sum(g(0.5))) / 0.1 for i in 1:1000) # 9.59
# Average of 100 calls to derivative_estimate for p -> sum(g(p)) at p = 0.5.
@show mean(derivative_estimate(p -> sum(g(p)), 0.5) for i in 1:100) # 8.84

@GuusAvis, let me know if you have any issues; if things work out, adding the above as a test to `triples.jl` would be most welcome :)

gaurav-arya avatar Aug 03 '24 01:08 gaurav-arya