Omega.jl
Implement `grad(x, ω, ZygoteGrad)`
Given a real-valued random variable `x`, we want to compute its gradient with respect to `ω`.
Gradient functionality lives in the package OmegaGrad, which is separate from OmegaCore, so you'll need to `dev` it.
There are several ways to represent the gradient, and each corresponds to a different interface in OmegaGrad.jl. The first one we'll focus on is `grad`:
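```julia
using OmegaCore, OmegaGrad   # OmegaGrad has to be dev'ed separately
using Distributions          # assumed source of `Normal`

x = 1 ~ Normal(0, 1)              # random variable with component id 1
ω = defω()                        # default ω
x_ = x(ω)                         # value of x at ω
ωgrad = grad(x, ω, ZygoteGrad)    # gradient of x with respect to ω
```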
`ωgrad` should itself be an `AbstractΩ`, like `ω`, but its components should hold the gradient of `x` with respect to each component of `ω` rather than the sampled values.
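To make the intended contract concrete without depending on OmegaGrad or Zygote, here is a minimal, dependency-free Julia sketch. `ToyΩ`, `normal_rv` and `fd_grad` are hypothetical names, not part of Omega. The gradient is computed by central finite differences as a stand-in for `ZygoteGrad`, and the Normal is assumed to be reparameterised from a standard-normal value stored in `ω`. It only illustrates the expected shape of the result: another Ω with the same component ids as `ω`, each holding ∂x/∂ω[id].

```julia
# Hypothetical toy stand-in for an AbstractΩ: maps component ids to sampled values.
struct ToyΩ
    data::Dict{Int,Float64}
end

# Hypothetical Normal(μ, σ) random variable, assuming it is reparameterised from a
# standard-normal value stored in ω under component `id`.
normal_rv(id, μ, σ) = ω -> μ + σ * ω.data[id]

# Central finite-difference gradient (a stand-in for ZygoteGrad): returns a ToyΩ
# with the same ids as ω, where each value is ∂x/∂ω[id] instead of a sample.
function fd_grad(x, ω::ToyΩ; h = 1e-6)
    g = Dict{Int,Float64}()
    for k in keys(ω.data)
        up = copy(ω.data); up[k] += h   # nudge component k upward
        dn = copy(ω.data); dn[k] -= h   # nudge component k downward
        g[k] = (x(ToyΩ(up)) - x(ToyΩ(dn))) / (2h)
    end
    return ToyΩ(g)
end

ω = ToyΩ(Dict(1 => 0.3, 2 => -1.2))
x = normal_rv(1, 0.0, 1.0)     # depends only on component 1
y = ω -> x(ω)^2                # derived random variable
ωgrad_x = fd_grad(x, ω)        # ≈ Dict(1 => 1.0, 2 => 0.0)
ωgrad_y = fd_grad(y, ω)        # ≈ Dict(1 => 0.6, 2 => 0.0)
```

A real `grad(x, ω, ZygoteGrad)` would presumably use reverse-mode AD instead of finite differences, but the returned value would mirror `ω` in the same way: components that `x` does not depend on get a zero gradient.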