
Implement `grad(x, ω, ZygoteGrad)`

Status: Open. zenna opened this issue 5 years ago; 0 comments.

Given a real-valued random variable, we want to be able to compute its gradient with respect to the components of ω.

Gradients live in the package OmegaGrad, which is not part of OmegaCore, so you'll need to `dev` it.

There are several ways to represent the gradient, each expressed by a different interface in OmegaGrad.jl. The one we'll focus on first is `grad`:

x = 1 ~ Normal(0, 1)
ω = defω()
x_ = x(ω)
ωgrad = grad(x, ω, ZygoteGrad)

`ωgrad` should be an `ω::AbstractΩ` whose components now hold the gradient values, i.e. the derivative of `x` with respect to each corresponding component of `ω`.
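To make the intended result concrete, here is a toy sketch of the same idea without Omega. It assumes ω can be modelled as a plain `Vector{Float64}` of exogenous standard-normal draws (this is not Omega's actual `AbstractΩ` type), and that `x = 1 ~ Normal(0, 1)` behaves like a function reading component 1 of ω. The names `μ`, `σ`, and the vector representation are hypothetical; only `Zygote.gradient` is a real API. The gradient has the same shape as ω, which is what `ωgrad` is meant to capture.

```julia
using Zygote  # assumes Zygote.jl is installed

# Hypothetical toy model: ω is a plain vector of standard-normal draws,
# standing in for Omega's AbstractΩ.
ω = randn(3)

# A random variable is a function of ω. Here x = μ + σ * ω[1], loosely
# mirroring how `1 ~ Normal(0, 1)` depends on component 1 of ω.
μ, σ = 0.0, 1.0
x(ω) = μ + σ * ω[1]

# Zygote.gradient returns a tuple with one gradient per argument;
# the gradient w.r.t. ω has the same shape as ω.
(ωgrad,) = Zygote.gradient(x, ω)
```

Only component 1 of ω influences `x`, so `ωgrad` is `σ` in that slot and zero elsewhere. In Omega, `grad(x, ω, ZygoteGrad)` would package the same information back into an `AbstractΩ` rather than a vector.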

zenna, Jul 15 '20 14:07