StochasticAD.jl
Stochastic triple `getindex` rule does not support vector-valued array elements
Hi, it appears to me that higher-dimensional random walks are not supported. Is this an inherent limitation of the approach, or is there a way to make this work?
A minimal example based on the random walk example:
using Distributions # defines several supported discrete distributions
using StochasticAD
using LinearAlgebra
function simulate_walk(probs, steps, n)
    state = [0, 0]
    for i in 1:n
        probs_here = probs(state) # transition probabilities for possible steps
        step_index = rand(Categorical(probs_here)) # which step do we take?
        step = steps[step_index] # Error happens here
        state += step
    end
    return norm(state)
end
steps = [[0, 1],[0,-1], [1, 0], [-1, 0]] # move in any direction
make_probs(p) = X -> [1 - exp(-norm(X) / norm(p)), exp(-norm(X) / norm(p)), 0, 0]
f(p, n) = simulate_walk(make_probs(p), steps, n)
@show f(50, 100) # let's run a single random walk with p = 50
@show stochastic_triple(p -> f(p, 100), 50) # let's see what a single stochastic triple looks like at p = 50
f_squared(p, n) = f(p, n)^2
samples = [derivative_estimate(p -> f_squared(p, 100), 50) for i in 1:1000] # many samples from derivative program at p = 50
derivative = mean(samples)
uncertainty = std(samples) / sqrt(1000)
println("derivative of 𝔼[f_squared] = $derivative ± $uncertainty")
This results in:
ERROR: LoadError: MethodError: no method matching value(::Vector{Int64})
I suppose the `Base.getindex` method in `general_rules.jl` could be modified to work with a vector of vectors, but I haven't yet figured out how. Or could there be a workaround using `propagate`? Probably related to #117.
Thanks!
Hi! Indeed, this is not a fundamental limitation, just a lack of generality in our `getindex` rule.
Let's leave this issue open until I generalize the `getindex` rule, but you should indeed be able to work around it via `propagate`: see https://gaurav-arya.github.io/StochasticAD.jl/stable/devdocs.html#StochasticAD.propagate. For your particular case, a simpler solution could be to replace `steps[step_index]` with a map over each dimension, so that only the scalar `getindex` is used.
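A minimal sketch of that per-dimension lookup, for illustration only. The helper name `lookup_step` is hypothetical and not part of StochasticAD. The idea is to gather the `d`-th coordinate of every candidate step into a plain `Vector{Int}` and index that, so a stochastic-triple `step_index` would only ever reach the scalar `getindex` rule. The check below uses ordinary integer indices and needs no packages.

```julia
# Hypothetical helper (not part of StochasticAD): look up a vector-valued step
# one coordinate at a time, so each lookup indexes a plain vector of scalars.
function lookup_step(steps, step_index)
    return map(eachindex(first(steps))) do d
        coords = [s[d] for s in steps]  # d-th coordinate of every candidate step
        coords[step_index]              # scalar getindex only
    end
end

steps = [[0, 1], [0, -1], [1, 0], [-1, 0]]

# With an ordinary integer index this agrees with direct indexing.
for i in eachindex(steps)
    @assert lookup_step(steps, i) == steps[i]
end
```

In `simulate_walk`, the failing line `step = steps[step_index]` would then become `step = lookup_step(steps, step_index)`, assuming the scalar rule handles each coordinate lookup.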
Sounds good. I had problems with using `propagate`, but using `map` to convert to a scalar index works perfectly. Thanks for the suggestion!